{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/Bonus_UnsupervisedLearning/Tutorial2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neuromatch Academy: Week 3, Day 5, Tutorial 2\n",
    "# Deep Learning 2: Autoencoder extensions\n",
    "\n",
    "__Content creators:__ Marco Brigham and the [CCNSS](https://www.ccnss.org/) team (2014-2018)\n",
    "\n",
    "__Content reviewers:__ Itzel Olivos, Karen Schroeder, Karolina Stosio, Kshitij Dwivedi, Spiros Chavlis, Michael Waskom"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
    "## Architecture\n",
    "How can we improve the internal representation of shallow autoencoder with 2D bottleneck layer? \n",
    "\n",
    "We may try the following architecture changes:\n",
    "* Introducing additional hidden layers\n",
    "* Wrapping latent space as a sphere\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![Deep ANN autoencoder](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/ae-ann-3h.png)\n",
    "\n",
    "Adding hidden layers increases the number of learnable parameters to better use non-linear operations in encoding/decoding. Spherical geometry of latent space forces the network to use these additional degrees of freedom more efficiently.\n",
    "\n",
    "Let's dive deeper into the technical aspects of autoencoders and improve their internal representations to reach the levels required for the *MNIST cognitive task*.\n",
    "\n",
    "In this tutorial, you will:\n",
    "- Increase the capacity of the network by introducing additional hidden layers\n",
    "- Understand the effect of constraints in the geometry of latent space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:03.111592Z",
     "iopub.status.busy": "2021-05-25T01:03:03.111003Z",
     "iopub.status.idle": "2021-05-25T01:03:03.174833Z",
     "shell.execute_reply": "2021-05-25T01:03:03.174288Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Video 1: Extensions\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"pgkrU9UqXiU\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Setup\n",
    "Please execute the cell(s) below to initialize the notebook environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:03.178122Z",
     "iopub.status.busy": "2021-05-25T01:03:03.177575Z",
     "iopub.status.idle": "2021-05-25T01:03:04.162427Z",
     "shell.execute_reply": "2021-05-25T01:03:04.161322Z"
    }
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "\n",
    "from sklearn.datasets import fetch_openml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:04.170298Z",
     "iopub.status.busy": "2021-05-25T01:03:04.169730Z",
     "iopub.status.idle": "2021-05-25T01:03:13.361936Z",
     "shell.execute_reply": "2021-05-25T01:03:13.361274Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Figure settings\n",
    "!pip install plotly --quiet\n",
    "import plotly.graph_objects as go\n",
    "from plotly.colors import qualitative\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:13.382390Z",
     "iopub.status.busy": "2021-05-25T01:03:13.370998Z",
     "iopub.status.idle": "2021-05-25T01:03:13.419046Z",
     "shell.execute_reply": "2021-05-25T01:03:13.418058Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Helper functions\n",
    "\n",
    "\n",
    "def downloadMNIST():\n",
    "  \"\"\"\n",
    "  Download MNIST dataset and transform it to torch.Tensor\n",
    "\n",
    "  Args:\n",
    "    None\n",
    "\n",
    "  Returns:\n",
    "    x_train : training images (torch.Tensor) (60000, 28, 28)\n",
    "    x_test  : test images (torch.Tensor) (10000, 28, 28)\n",
    "    y_train : training labels (torch.Tensor) (60000, )\n",
    "    y_train : test labels (torch.Tensor) (10000, )\n",
    "  \"\"\"\n",
    "  X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame = False)\n",
    "  # Trunk the data\n",
    "  n_train = 60000\n",
    "  n_test = 10000\n",
    "\n",
    "  train_idx = np.arange(0, n_train)\n",
    "  test_idx = np.arange(n_train, n_train + n_test)\n",
    "\n",
    "  x_train, y_train = X[train_idx], y[train_idx]\n",
    "  x_test, y_test = X[test_idx], y[test_idx]\n",
    "\n",
    "  # Transform np.ndarrays to torch.Tensor\n",
    "  x_train = torch.from_numpy(np.reshape(x_train,\n",
    "                                        (len(x_train),\n",
    "                                         28, 28)).astype(np.float32))\n",
    "  x_test = torch.from_numpy(np.reshape(x_test,\n",
    "                                       (len(x_test),\n",
    "                                        28, 28)).astype(np.float32))\n",
    "\n",
    "  y_train = torch.from_numpy(y_train.astype(int))\n",
    "  y_test = torch.from_numpy(y_test.astype(int))\n",
    "\n",
    "  return (x_train, y_train, x_test, y_test)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_uniform(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming uniform distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming uniform distribution\n",
    "    nn.init.kaiming_uniform_(layer.weight.data)\n",
    "\n",
    "\n",
    "def init_weights_kaiming_normal(layer):\n",
    "  \"\"\"\n",
    "  Initializes weights from linear PyTorch layer\n",
    "  with kaiming normal distribution.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "  # check for linear PyTorch layer\n",
    "  if isinstance(layer, nn.Linear):\n",
    "    # initialize weights with kaiming normal distribution\n",
    "    nn.init.kaiming_normal_(layer.weight.data)\n",
    "\n",
    "\n",
    "def get_layer_weights(layer):\n",
    "  \"\"\"\n",
    "  Retrieves learnable parameters from PyTorch layer.\n",
    "\n",
    "  Args:\n",
    "    layer (torch.Module)\n",
    "        Pytorch layer\n",
    "\n",
    "  Returns:\n",
    "    list with learnable parameters\n",
    "  \"\"\"\n",
    "  # initialize output list\n",
    "  weights = []\n",
    "\n",
    "  # check whether layer has learnable parameters\n",
    "  if layer.parameters():\n",
    "    # copy numpy array representation of each set of learnable parameters\n",
    "    for item in layer.parameters():\n",
    "      weights.append(item.detach().numpy())\n",
    "\n",
    "  return weights\n",
    "\n",
    "\n",
    "def print_parameter_count(net):\n",
    "  \"\"\"\n",
    "  Prints count of learnable parameters per layer from PyTorch network.\n",
    "\n",
    "  Args:\n",
    "    net (torch.Sequential)\n",
    "        Pytorch network\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  params_n = 0\n",
    "\n",
    "  # loop all layers in network\n",
    "  for layer_idx, layer in enumerate(net):\n",
    "\n",
    "    # retrieve learnable parameters\n",
    "    weights = get_layer_weights(layer)\n",
    "    params_layer_n = 0\n",
    "\n",
    "    # loop list of learnable parameters and count them\n",
    "    for params in weights:\n",
    "      params_layer_n += params.size\n",
    "\n",
    "    params_n += params_layer_n\n",
    "    print(f'{layer_idx}\\t {params_layer_n}\\t {layer}')\n",
    "\n",
    "  print(f'\\nTotal:\\t {params_n}')\n",
    "\n",
    "\n",
    "def eval_mse(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates mean square error (MSE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    v (numpy array of floats)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    MSE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "      criterion = nn.MSELoss()\n",
    "      loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def eval_bce(y_pred, y_true):\n",
    "  \"\"\"\n",
    "  Evaluates binary cross-entropy (BCE) between y_pred and y_true\n",
    "\n",
    "  Args:\n",
    "    y_pred (torch.Tensor)\n",
    "        prediction samples\n",
    "\n",
    "    v (numpy array of floats)\n",
    "        ground truth samples\n",
    "\n",
    "  Returns:\n",
    "    BCE(y_pred, y_true)\n",
    "  \"\"\"\n",
    "\n",
    "  with torch.no_grad():\n",
    "    criterion = nn.BCELoss()\n",
    "    loss = criterion(y_pred, y_true)\n",
    "\n",
    "  return float(loss)\n",
    "\n",
    "\n",
    "def plot_row(images, show_n=10, image_shape=None):\n",
    "  \"\"\"\n",
    "  Plots rows of images from list of iterables (iterables: list, numpy array\n",
    "  or torch.Tensor). Also accepts single iterable.\n",
    "  Randomly selects images in each list element if item count > show_n.\n",
    "\n",
    "  Args:\n",
    "    images (iterable or list of iterables)\n",
    "        single iterable with images, or list of iterables\n",
    "\n",
    "    show_n (integer)\n",
    "        maximum number of images per row\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image if vectorized form\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if not isinstance(images, (list, tuple)):\n",
    "    images = [images]\n",
    "\n",
    "  for items_idx, items in enumerate(images):\n",
    "\n",
    "    items = np.array(items)\n",
    "    if items.ndim == 1:\n",
    "      items = np.expand_dims(items, axis=0)\n",
    "\n",
    "    if len(items) > show_n:\n",
    "      selected = np.random.choice(len(items), show_n, replace=False)\n",
    "      items = items[selected]\n",
    "\n",
    "    if image_shape is not None:\n",
    "      items = items.reshape([-1]+list(image_shape))\n",
    "\n",
    "    plt.figure(figsize=(len(items) * 1.5, 2))\n",
    "    for image_idx, image in enumerate(items):\n",
    "\n",
    "      plt.subplot(1, len(items), image_idx + 1)\n",
    "      plt.imshow(image, cmap='gray', vmin=image.min(), vmax=image.max())\n",
    "      plt.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "\n",
    "def to_s2(u):\n",
    "  \"\"\"\n",
    "  Projects 3D coordinates to spherical coordinates (theta, phi) surface of\n",
    "  unit sphere S2.\n",
    "  theta: [0, pi]\n",
    "  phi: [-pi, pi]\n",
    "\n",
    "  Args:\n",
    "    u (list, numpy array or torch.Tensor of floats)\n",
    "        3D coordinates\n",
    "\n",
    "  Returns:\n",
    "    Sperical coordinates (theta, phi) on surface of unit sphere S2.\n",
    "  \"\"\"\n",
    "\n",
    "  x, y, z = (u[:, 0], u[:, 1], u[:, 2])\n",
    "  r = np.sqrt(x**2 + y**2 + z**2)\n",
    "  theta = np.arccos(z / r)\n",
    "  phi = np.arctan2(x, y)\n",
    "\n",
    "  return np.array([theta, phi]).T\n",
    "\n",
    "\n",
    "def to_u3(s):\n",
    "  \"\"\"\n",
    "  Converts from 2D coordinates on surface of unit sphere S2 to 3D coordinates\n",
    "  (on surface of S2), i.e. (theta, phi) ---> (1, theta, phi).\n",
    "\n",
    "  Args:\n",
    "    s (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates on unit sphere S_2\n",
    "\n",
    "  Returns:\n",
    "    3D coordinates on surface of unit sphere S_2\n",
    "  \"\"\"\n",
    "\n",
    "  theta, phi = (s[:, 0], s[:, 1])\n",
    "  x = np.sin(theta) * np.sin(phi)\n",
    "  y = np.sin(theta) * np.cos(phi)\n",
    "  z = np.cos(theta)\n",
    "\n",
    "  return np.array([x, y, z]).T\n",
    "\n",
    "\n",
    "def xy_lim(x):\n",
    "  \"\"\"\n",
    "  Return arguments for plt.xlim and plt.ylim calculated from minimum\n",
    "  and maximum of x.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        data to be plotted\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  x_min = np.min(x, axis=0)\n",
    "  x_max = np.max(x, axis=0)\n",
    "\n",
    "  x_min = x_min - np.abs(x_max - x_min) * 0.05 - np.finfo(float).eps\n",
    "  x_max = x_max + np.abs(x_max - x_min) * 0.05 + np.finfo(float).eps\n",
    "\n",
    "  return [x_min[0], x_max[0]], [x_min[1], x_max[1]]\n",
    "\n",
    "\n",
    "def plot_generative(x, decoder_fn, image_shape, n_row=16, s2=False):\n",
    "  \"\"\"\n",
    "  Plots images reconstructed by decoder_fn from a 2D grid in\n",
    "  latent space that is determined by minimum and maximum values in x.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D or 3D coordinates in latent space\n",
    "\n",
    "    decoder_fn (integer)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    n_row (integer)\n",
    "        number of rows in grid\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if s2:\n",
    "    x = to_s2(np.array(x))\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "\n",
    "  dx = (xlim[1] - xlim[0]) / n_row\n",
    "  grid = [np.linspace(ylim[0] + dx / 2, ylim[1] - dx / 2, n_row),\n",
    "          np.linspace(xlim[0] + dx / 2, xlim[1] - dx / 2, n_row)]\n",
    "\n",
    "  canvas = np.zeros((image_shape[0] * n_row, image_shape[1] * n_row))\n",
    "\n",
    "  cmap = plt.get_cmap('gray')\n",
    "\n",
    "  for j, latent_y in enumerate(grid[0][::-1]):\n",
    "    for i, latent_x in enumerate(grid[1]):\n",
    "\n",
    "      latent = np.array([[latent_x, latent_y]], dtype=np.float32)\n",
    "\n",
    "      if s2:\n",
    "        latent = to_u3(latent)\n",
    "\n",
    "      with torch.no_grad():\n",
    "        x_decoded = decoder_fn(torch.from_numpy(latent))\n",
    "\n",
    "      x_decoded = x_decoded.reshape(image_shape)\n",
    "\n",
    "      canvas[j * image_shape[0]: (j + 1) * image_shape[0],\n",
    "             i * image_shape[1]: (i + 1) * image_shape[1]] = x_decoded\n",
    "\n",
    "  plt.imshow(canvas, cmap=cmap, vmin=canvas.min(), vmax=canvas.max())\n",
    "  plt.axis('off')\n",
    "\n",
    "\n",
    "def plot_latent(x, y, show_n=500, s2=False, fontdict=None, xy_labels=None):\n",
    "  \"\"\"\n",
    "  Plots digit class of each sample in 2D latent space coordinates.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    n_row (integer)\n",
    "        number of samples\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "    fontdict (dictionary)\n",
    "        style option for plt.text\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  if fontdict is None:\n",
    "    fontdict = {'weight': 'bold', 'size': 12}\n",
    "\n",
    "  if s2:\n",
    "    x = to_s2(np.array(x))\n",
    "\n",
    "  cmap = plt.get_cmap('tab10')\n",
    "\n",
    "  if len(x) > show_n:\n",
    "    selected = np.random.choice(len(x), show_n, replace=False)\n",
    "    x = x[selected]\n",
    "    y = y[selected]\n",
    "\n",
    "  for my_x, my_y in zip(x, y):\n",
    "    plt.text(my_x[0], my_x[1], str(int(my_y)),\n",
    "             color=cmap(int(my_y) / 10.),\n",
    "             fontdict=fontdict,\n",
    "             horizontalalignment='center',\n",
    "             verticalalignment='center',\n",
    "             alpha=0.8)\n",
    "\n",
    "  xlim, ylim = xy_lim(np.array(x))\n",
    "  plt.xlim(xlim)\n",
    "  plt.ylim(ylim)\n",
    "\n",
    "  if s2:\n",
    "    if xy_labels is None:\n",
    "      xy_labels = [r'$\\varphi$', r'$\\theta$']\n",
    "\n",
    "    plt.xticks(np.arange(0, np.pi + np.pi / 6, np.pi / 6),\n",
    "               ['0', '$\\pi/6$', '$\\pi/3$', '$\\pi/2$',\n",
    "                '$2\\pi/3$', '$5\\pi/6$', '$\\pi$'])\n",
    "    plt.yticks(np.arange(-np.pi, np.pi + np.pi / 3, np.pi / 3),\n",
    "               ['$-\\pi$', '$-2\\pi/3$', '$-\\pi/3$', '0',\n",
    "                '$\\pi/3$', '$2\\pi/3$', '$\\pi$'])\n",
    "\n",
    "  if xy_labels is None:\n",
    "    xy_labels = ['$Z_1$', '$Z_2$']\n",
    "\n",
    "  plt.xlabel(xy_labels[0])\n",
    "  plt.ylabel(xy_labels[1])\n",
    "\n",
    "\n",
    "def plot_latent_generative(x, y, decoder_fn, image_shape, s2=False,\n",
    "                           title=None, xy_labels=None):\n",
    "  \"\"\"\n",
    "  Two horizontal subplots generated with encoder map and decoder grid.\n",
    "\n",
    "  Args:\n",
    "    x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    decoder_fn (integer)\n",
    "        function returning vectorized images from 2D latent space coordinates\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "    title (string)\n",
    "        plot title\n",
    "\n",
    "    xy_labels (list)\n",
    "        optional list with [xlabel, ylabel]\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  fig = plt.figure(figsize=(12, 6))\n",
    "\n",
    "  if title is not None:\n",
    "    fig.suptitle(title, y=1.05)\n",
    "\n",
    "  ax = fig.add_subplot(121)\n",
    "  ax.set_title('Encoder map', y=1.05)\n",
    "  plot_latent(x, y, s2=s2, xy_labels=xy_labels)\n",
    "\n",
    "  ax = fig.add_subplot(122)\n",
    "  ax.set_title('Decoder grid', y=1.05)\n",
    "  plot_generative(x, decoder_fn, image_shape, s2=s2)\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "def plot_latent_3d(my_x, my_y, show_text=True, show_n=500):\n",
    "  \"\"\"\n",
    "  Plot digit class or marker in 3D latent space coordinates.\n",
    "\n",
    "  Args:\n",
    "    my_x (list, numpy array or torch.Tensor of floats)\n",
    "        2D coordinates in latent space\n",
    "\n",
    "    my_y (list, numpy array or torch.Tensor of floats)\n",
    "        digit class of each sample\n",
    "\n",
    "    show_text (boolean)\n",
    "        whether to show text\n",
    "\n",
    "    image_shape (tuple or list)\n",
    "        original shape of image\n",
    "\n",
    "    s2 (boolean)\n",
    "        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)\n",
    "\n",
    "    title (string)\n",
    "        plot title\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  layout = {'margin': {'l': 0, 'r': 0, 'b': 0, 't': 0},\n",
    "            'scene': {'xaxis': {'showspikes': False,\n",
    "                                'title': 'z1'},\n",
    "                      'yaxis': {'showspikes': False,\n",
    "                                'title': 'z2'},\n",
    "                      'zaxis': {'showspikes': False,\n",
    "                                'title': 'z3'}}\n",
    "            }\n",
    "\n",
    "  selected_idx = np.random.choice(len(my_x), show_n, replace=False)\n",
    "\n",
    "  colors = [qualitative.T10[idx] for idx in my_y[selected_idx]]\n",
    "\n",
    "  x = my_x[selected_idx, 0]\n",
    "  y = my_x[selected_idx, 1]\n",
    "  z = my_x[selected_idx, 2]\n",
    "\n",
    "  text = my_y[selected_idx]\n",
    "\n",
    "  if show_text:\n",
    "\n",
    "    trace = go.Scatter3d(x=x, y=y, z=z, text=text,\n",
    "                         mode='text',\n",
    "                         textfont={'color': colors, 'size': 12}\n",
    "                         )\n",
    "\n",
    "    layout['hovermode'] = False\n",
    "\n",
    "  else:\n",
    "\n",
    "    trace = go.Scatter3d(x=x, y=y, z=z, text=text,\n",
    "                         hoverinfo='text', mode='markers',\n",
    "                         marker={'size': 5, 'color': colors, 'opacity': 0.8}\n",
    "                         )\n",
    "\n",
    "  fig = go.Figure(data=trace, layout=layout)\n",
    "\n",
    "  fig.show()\n",
    "\n",
    "\n",
    "def runSGD(net, input_train, input_test, criterion='bce',\n",
    "           n_epochs=10, batch_size=32, verbose=False):\n",
    "  \"\"\"\n",
    "  Trains autoencoder network with stochastic gradient descent with Adam\n",
    "  optimizer and loss criterion. Train samples are shuffled, and loss is\n",
    "  displayed at the end of each opoch for both MSE and BCE. Plots training loss\n",
    "  at each minibatch (maximum of 500 randomly selected values).\n",
    "\n",
    "  Args:\n",
    "    net (torch network)\n",
    "        ANN object (nn.Module)\n",
    "\n",
    "    input_train (torch.Tensor)\n",
    "        vectorized input images from train set\n",
    "\n",
    "    input_test (torch.Tensor)\n",
    "        vectorized input images from test set\n",
    "\n",
    "    criterion (string)\n",
    "        train loss: 'bce' or 'mse'\n",
    "\n",
    "    n_epochs (boolean)\n",
    "        number of full iterations of training data\n",
    "\n",
    "    batch_size (integer)\n",
    "        number of element in mini-batches\n",
    "\n",
    "    verbose (boolean)\n",
    "        print final loss\n",
    "\n",
    "  Returns:\n",
    "    Nothing.\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize loss function\n",
    "  if criterion == 'mse':\n",
    "    loss_fn = nn.MSELoss()\n",
    "  elif criterion == 'bce':\n",
    "    loss_fn = nn.BCELoss()\n",
    "  else:\n",
    "    print('Please specify either \"mse\" or \"bce\" for loss criterion')\n",
    "\n",
    "  # Initialize SGD optimizer\n",
    "  optimizer = optim.Adam(net.parameters())\n",
    "\n",
    "  # Placeholder for loss\n",
    "  track_loss = []\n",
    "\n",
    "  print('Epoch', '\\t', 'Loss train', '\\t', 'Loss test')\n",
    "  for i in range(n_epochs):\n",
    "\n",
    "    shuffle_idx = np.random.permutation(len(input_train))\n",
    "    batches = torch.split(input_train[shuffle_idx], batch_size)\n",
    "\n",
    "    for batch in batches:\n",
    "\n",
    "      output_train = net(batch)\n",
    "      loss = loss_fn(output_train, batch)\n",
    "      optimizer.zero_grad()\n",
    "      loss.backward()\n",
    "      optimizer.step()\n",
    "\n",
    "      # Keep track of loss at each epoch\n",
    "      track_loss += [float(loss)]\n",
    "\n",
    "    loss_epoch = f'{i+1}/{n_epochs}'\n",
    "    with torch.no_grad():\n",
    "      output_train = net(input_train)\n",
    "      loss_train = loss_fn(output_train, input_train)\n",
    "      loss_epoch += f'\\t {loss_train:.4f}'\n",
    "\n",
    "      output_test = net(input_test)\n",
    "      loss_test = loss_fn(output_test, input_test)\n",
    "      loss_epoch += f'\\t\\t {loss_test:.4f}'\n",
    "\n",
    "    print(loss_epoch)\n",
    "\n",
    "  if verbose:\n",
    "    # Print loss\n",
    "    loss_mse = f'\\nMSE\\t {eval_mse(output_train, input_train):0.4f}'\n",
    "    loss_mse += f'\\t\\t {eval_mse(output_test, input_test):0.4f}'\n",
    "    print(loss_mse)\n",
    "\n",
    "    loss_bce = f'BCE\\t {eval_bce(output_train, input_train):0.4f}'\n",
    "    loss_bce += f'\\t\\t {eval_bce(output_test, input_test):0.4f}'\n",
    "    print(loss_bce)\n",
    "\n",
    "  # Plot loss\n",
    "  step = int(np.ceil(len(track_loss) / 500))\n",
    "  x_range = np.arange(0, len(track_loss), step)\n",
    "  plt.figure()\n",
    "  plt.plot(x_range, track_loss[::step], 'C0')\n",
    "  plt.xlabel('Iterations')\n",
    "  plt.ylabel('Loss')\n",
    "  plt.xlim([0, None])\n",
    "  plt.ylim([0, None])\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "class NormalizeLayer(nn.Module):\n",
    "  \"\"\"\n",
    "  pyTorch layer (nn.Module) that normalizes activations by their L2 norm.\n",
    "\n",
    "  Args:\n",
    "      None.\n",
    "\n",
    "  Returns:\n",
    "      Object inherited from nn.Module class.\n",
    "  \"\"\"\n",
    "\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "\n",
    "  def forward(self, x):\n",
    "    return nn.functional.normalize(x, p=2, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 1: Download  and prepare MNIST dataset\n",
    "We use the helper function `downloadMNIST` to download the dataset and transform it into `torch.Tensor` and assign train and test sets to (`x_train`, `y_train`) and (`x_test`, `y_test`).\n",
    "\n",
    "The variable `input_size` stores the length of *vectorized* versions of the images `input_train` and `input_test` for training and test images.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:13.425417Z",
     "iopub.status.busy": "2021-05-25T01:03:13.424460Z",
     "iopub.status.idle": "2021-05-25T01:03:35.519418Z",
     "shell.execute_reply": "2021-05-25T01:03:35.518549Z"
    }
   },
   "outputs": [],
   "source": [
    "# Download MNIST\n",
    "x_train, y_train, x_test, y_test = downloadMNIST()\n",
    "\n",
    "x_train = x_train / 255\n",
    "x_test = x_test / 255\n",
    "\n",
    "image_shape = x_train.shape[1:]\n",
    "\n",
    "input_size = np.prod(image_shape)\n",
    "\n",
    "input_train = x_train.reshape([-1, input_size])\n",
    "input_test = x_test.reshape([-1, input_size])\n",
    "\n",
    "test_selected_idx = np.random.choice(len(x_test), 10, replace=False)\n",
    "train_selected_idx = np.random.choice(len(x_train), 10, replace=False)\n",
    "\n",
    "print(f'shape image \\t \\t {image_shape}')\n",
    "print(f'shape input_train \\t {input_train.shape}')\n",
    "print(f'shape input_test \\t {input_test.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 2: Deeper autoencoder (2D)\n",
    "The internal representation of shallow autoencoder with 2D latent space is similar to PCA, which shows that the autoencoder is not fully leveraging non-linear capabilities to model data. Adding capacity in terms of learnable parameters takes advantage of non-linear operations in encoding/decoding to capture non-linear patterns in data.\n",
    "\n",
    "Adding hidden layers enables us to introduce additional parameters, either layerwise or depthwise. The same amount $N$ of additional parameters can be added in a single layer or distributed among several layers. Adding several hidden layers reduces the compression/decompression ratio of each layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercise 1: Build deeper autoencoder (2D)\n",
    "Implement this deeper version of the ANN autoencoder by adding four hidden layers. The number of units per layer in the encoder is the following:\n",
    "\n",
    "```\n",
    "784 -> 392 -> 64 -> 2\n",
    "```\n",
    "\n",
    "The shallow autoencoder has a compression ratio of **784:2 = 392:1**. The first additional hidden layer has a compression ratio of **2:1**,  followed by a hidden layer that sets the bottleneck compression ratio of **32:1**.\n",
    "\n",
    "The choice of hidden layer size aims to reduce the compression rate in the bottleneck layer while increasing the count of trainable parameters.  For example, if the compression rate of the first hidden layer doubles from **2:1** to **4:1**, the count of trainable parameters halves from 667K to 333K.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "This deep autoencoder's performance may be further improved by adding additional hidden layers and by increasing the count of trainable parameters in each layer. These improvements have a diminishing return due to challenges associated with training under high parameter count and depth. One option explored in the *Bonus* section is to add a first hidden layer with 2x - 3x the input size. This size increase results in millions of parameters at the cost of longer training time.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "Weight initialization is particularly important in deep networks. The availability of large datasets and weight initialization likely drove the deep learning revolution of 2010. We'll implement Kaiming normal as follows:\n",
    "```\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "```\n",
    "\n",
    "**Instructions:**\n",
    "* Add four additional layers and activation functions to the network\n",
    "* Adjust the definitions of `encoder` and `decoder`\n",
    "* Check learnable parameter count for this autoencoder by executing the last cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:35.525004Z",
     "iopub.status.busy": "2021-05-25T01:03:35.524477Z",
     "iopub.status.idle": "2021-05-25T01:03:35.533696Z",
     "shell.execute_reply": "2021-05-25T01:03:35.534162Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "    #################################################\n",
    "    ## TODO for students: add layers to build deeper autoencoder\n",
    "    #################################################\n",
    "    # Add activation function\n",
    "    # ...,\n",
    "    # Add another layer\n",
    "    # nn.Linear(..., ...),\n",
    "    # Add activation function\n",
    "    # ...,\n",
    "    # Add another layer\n",
    "    # nn.Linear(..., ...),\n",
    "    # Add activation function\n",
    "    # ...,\n",
    "    # Add another layer\n",
    "    # nn.Linear(..., ...),\n",
    "    # Add activation function\n",
    "    # ...,\n",
    "    # Add another layer\n",
    "    # nn.Linear(..., ...),\n",
    "    # Add activation function\n",
    "    # ....\n",
    "    )\n",
    "\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}\\n')\n",
    "\n",
    "# Adjust the value n_l to split your model correctly\n",
    "# n_l = ...\n",
    "\n",
    "# uncomment when you fill the code\n",
    "# encoder = model[:n_l]\n",
    "# decoder = model[n_l:]\n",
    "# print(f'Encoder \\n\\n {encoder}\\n')\n",
    "# print(f'Decoder \\n\\n {decoder}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:35.540399Z",
     "iopub.status.busy": "2021-05-25T01:03:35.539854Z",
     "iopub.status.idle": "2021-05-25T01:03:35.553848Z",
     "shell.execute_reply": "2021-05-25T01:03:35.553416Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "encoding_size = 2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "    # Add activation function\n",
    "    nn.PReLU(),\n",
    "    # Add another layer\n",
    "    nn.Linear(encoding_size * 32, encoding_size),\n",
    "    # Add activation function\n",
    "    nn.PReLU(),\n",
    "    # Add another layer\n",
    "    nn.Linear(encoding_size, encoding_size * 32),\n",
    "    # Add activation function\n",
    "    nn.PReLU(),\n",
    "    # Add another layer\n",
    "    nn.Linear(encoding_size * 32, int(input_size / 2)),\n",
    "    # Add activation function\n",
    "    nn.PReLU(),\n",
    "    # Add another layer\n",
    "    nn.Linear(int(input_size / 2), input_size),\n",
    "    # Add activation function\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}\\n')\n",
    "\n",
    "# Adjust the value n_l to split your model correctly\n",
    "n_l = 6\n",
    "\n",
    "# uncomment when you fill the code\n",
    "encoder = model[:n_l]\n",
    "decoder = model[n_l:]\n",
    "print(f'Encoder \\n\\n {encoder}\\n')\n",
    "print(f'Decoder \\n\\n {decoder}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Helper function:** `print_parameter_count`\n",
    "\n",
    "Please uncomment the line below to inspect this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:35.557631Z",
     "iopub.status.busy": "2021-05-25T01:03:35.557183Z",
     "iopub.status.idle": "2021-05-25T01:03:35.559415Z",
     "shell.execute_reply": "2021-05-25T01:03:35.558967Z"
    }
   },
   "outputs": [],
   "source": [
    "# help(print_parameter_count)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the autoencoder\n",
    "\n",
    "Train the network for `n_epochs=10` epochs with `batch_size=128`, and observe how the internal representation successfully captures additional digit classes.\n",
    "\n",
    "The encoder map shows well-separated clusters that correspond to the associated digits in the decoder grid. The decoder grid also shows that the network is robust to digit skewness, i.e., digits leaning to the left or the right are recognized in the same digit class.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cells below\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:03:35.573121Z",
     "iopub.status.busy": "2021-05-25T01:03:35.572556Z",
     "iopub.status.idle": "2021-05-25T01:04:18.419813Z",
     "shell.execute_reply": "2021-05-25T01:04:18.420266Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "runSGD(model, input_train, input_test, n_epochs=n_epochs,\n",
    "       batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:04:18.424721Z",
     "iopub.status.busy": "2021-05-25T01:04:18.424163Z",
     "iopub.status.idle": "2021-05-25T01:04:21.846103Z",
     "shell.execute_reply": "2021-05-25T01:04:21.846576Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Section 3: Spherical latent space\n",
    "\n",
    "The previous architecture generates representations that typically spread in different directions from coordinate $(z_1, z_2)=(0,0)$. This effect arises because the weights are initialized randomly around `0`.\n",
    "\n",
    "Adding a third unit to the bottleneck layer defines a coordinate $(z_1, z_2, z_3)$ in 3D space. The latent space from such a network will still spread out from $(z_1, z_2, z_3)=(0, 0, 0)$.\n",
    "\n",
    "Collapsing the latent space onto the surface of a sphere removes the possibility of spreading indefinitely from the origin $(0, 0, 0)$: every point sits at the same distance from the origin, and moving in any direction along the surface eventually leads back to the starting point. This constraint generates a representation that fills the surface of the sphere.\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "![Unit sphere S2](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/unit_sphere.png)\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "\n",
    "Projecting to the surface of the sphere is implemented by dividing the coordinates $(z_1, z_2, z_3)$ by their $L_2$ norm.\n",
    "\n",
    "$(z_1, z_2, z_3)\\longmapsto (s_1, s_2, s_3)=(z_1, z_2, z_3)/\\|(z_1, z_2, z_3)\\|_2=(z_1, z_2, z_3)/ \\sqrt{z_1^2+z_2^2+z_3^2}$\n",
    "\n",
    "This mapping projects to the surface of the [$S_2$ sphere](https://en.wikipedia.org/wiki/N-sphere) with unit radius. (Why?)"
   ]
  },
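  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short check of the unit-radius claim, written out as a worked equation. It uses only the definition of $(s_1, s_2, s_3)$ above and assumes $(z_1, z_2, z_3) \\neq (0, 0, 0)$, since the projection is undefined at the origin:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "s_1^2 + s_2^2 + s_3^2 &= \\frac{z_1^2 + z_2^2 + z_3^2}{\\|(z_1, z_2, z_3)\\|_2^2} \\\\\n",
    "&= \\frac{z_1^2 + z_2^2 + z_3^2}{z_1^2 + z_2^2 + z_3^2} = 1\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Every projected point therefore lies at distance 1 from the origin, whatever the length of the original vector."
   ]
  },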
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.1: Build and train autoencoder (3D)\n",
    "\n",
    "We start by adding one unit to the bottleneck layer and visualize the latent space in 3D.\n",
    "\n",
    "Please execute the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:04:21.853076Z",
     "iopub.status.busy": "2021-05-25T01:04:21.852507Z",
     "iopub.status.idle": "2021-05-25T01:04:21.865854Z",
     "shell.execute_reply": "2021-05-25T01:04:21.865322Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 3\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, encoding_size),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size, encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "encoder = model[:6]\n",
    "decoder = model[6:]\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.2: Train the autoencoder\n",
    "\n",
    "Train the network for `n_epochs=10` epochs with `batch_size=128`. Observe how the internal representation spreads from the origin and reaches much lower loss due to the additional degree of freedom in the bottleneck layer.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:04:21.881437Z",
     "iopub.status.busy": "2021-05-25T01:04:21.880828Z",
     "iopub.status.idle": "2021-05-25T01:05:05.862769Z",
     "shell.execute_reply": "2021-05-25T01:05:05.861874Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "runSGD(model, input_train, input_test, n_epochs=n_epochs,\n",
    "       batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.3: Visualize the latent space in 3D\n",
    "\n",
    "**Helper function**: `plot_latent_3d`\n",
    "\n",
    "Please uncomment the line below to inspect this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:05.867304Z",
     "iopub.status.busy": "2021-05-25T01:05:05.866102Z",
     "iopub.status.idle": "2021-05-25T01:05:05.867944Z",
     "shell.execute_reply": "2021-05-25T01:05:05.868399Z"
    }
   },
   "outputs": [],
   "source": [
    "# help(plot_latent_3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:05.876759Z",
     "iopub.status.busy": "2021-05-25T01:05:05.876199Z",
     "iopub.status.idle": "2021-05-25T01:05:06.184296Z",
     "shell.execute_reply": "2021-05-25T01:05:06.183792Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_3d(latent_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 2: Build deep autoencoder (2D) with latent spherical space\n",
    "We now constrain the latent space to the surface of a sphere $S_2$.\n",
    "\n",
    "\n",
    "**Instructions:**\n",
    "* Add the custom layer `NormalizeLayer` after the bottleneck layer\n",
    "* Adjust the definitions of `encoder` and `decoder`\n",
    "* Experiment with keyword `show_text=False` for `plot_latent_3d`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Helper function**: `NormalizeLayer`\n",
    "\n",
    "Please uncomment the line below to inspect this function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:06.189760Z",
     "iopub.status.busy": "2021-05-25T01:05:06.189169Z",
     "iopub.status.idle": "2021-05-25T01:05:06.193389Z",
     "shell.execute_reply": "2021-05-25T01:05:06.192922Z"
    }
   },
   "outputs": [],
   "source": [
    "# help(NormalizeLayer)"
   ]
  },
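  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The operation that `NormalizeLayer` is used for here, division of the bottleneck coordinates by their $L_2$ norm, can be illustrated on a single point in plain Python. This is a sketch for intuition only, not the helper's actual code: the function name `project_to_unit_sphere` and the `eps` guard against division by zero are assumptions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def project_to_unit_sphere(z, eps=1e-12):\n",
    "  # hypothetical helper: L2 norm of the coordinate tuple\n",
    "  norm = math.sqrt(sum(z_i ** 2 for z_i in z))\n",
    "  # eps avoids division by zero at the origin (an assumption)\n",
    "  norm = max(norm, eps)\n",
    "  return tuple(z_i / norm for z_i in z)\n",
    "\n",
    "\n",
    "s = project_to_unit_sphere((3.0, 0.0, 4.0))\n",
    "print(s)\n",
    "# length of the projected point\n",
    "print(math.sqrt(sum(s_i ** 2 for s_i in s)))\n",
    "```\n",
    "\n",
    "Points at different distances from the origin but along the same direction are mapped to the same location on the sphere, so only the direction of the bottleneck activity is passed on to the decoder."
   ]
  },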
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:06.199849Z",
     "iopub.status.busy": "2021-05-25T01:05:06.199244Z",
     "iopub.status.idle": "2021-05-25T01:05:06.213044Z",
     "shell.execute_reply": "2021-05-25T01:05:06.212556Z"
    }
   },
   "outputs": [],
   "source": [
    "encoding_size = 3\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, encoding_size),\n",
    "    nn.PReLU(),\n",
    "    #################################################\n",
    "    ## TODO for students: add custom normalize layer\n",
    "    #################################################\n",
    "    # add the normalization layer\n",
    "    # ...,\n",
    "    nn.Linear(encoding_size, encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}\\n')\n",
    "\n",
    "# Adjust the value n_l to split your model correctly\n",
    "# n_l = ...\n",
    "\n",
    "# uncomment when you fill the code\n",
    "# encoder = model[:n_l]\n",
    "# decoder = model[n_l:]\n",
    "# print(f'Encoder \\n\\n {encoder}\\n')\n",
    "# print(f'Decoder \\n\\n {decoder}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:06.219115Z",
     "iopub.status.busy": "2021-05-25T01:05:06.218582Z",
     "iopub.status.idle": "2021-05-25T01:05:06.231412Z",
     "shell.execute_reply": "2021-05-25T01:05:06.231860Z"
    }
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "encoding_size = 3\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(input_size, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, encoding_size),\n",
    "    nn.PReLU(),\n",
    "    # add the normalization layer\n",
    "    NormalizeLayer(),\n",
    "    nn.Linear(encoding_size, encoding_size * 32),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(encoding_size * 32, int(input_size / 2)),\n",
    "    nn.PReLU(),\n",
    "    nn.Linear(int(input_size / 2), input_size),\n",
    "    nn.Sigmoid()\n",
    "    )\n",
    "\n",
    "model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "print(f'Autoencoder \\n\\n {model}\\n')\n",
    "\n",
    "# Adjust the value n_l to split your model correctly\n",
    "n_l = 7\n",
    "\n",
    "# uncomment when you fill the code\n",
    "encoder = model[:n_l]\n",
    "decoder = model[n_l:]\n",
    "print(f'Encoder \\n\\n {encoder}\\n')\n",
    "print(f'Decoder \\n\\n {decoder}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.4: Train the autoencoder\n",
    "Train the network for `n_epochs=10` epochs with `batch_size=128` and observe how the loss rises again and becomes comparable to that of the model with a 2D latent space. This is consistent with the surface of the sphere having only two degrees of freedom.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:06.236337Z",
     "iopub.status.busy": "2021-05-25T01:05:06.235827Z",
     "iopub.status.idle": "2021-05-25T01:05:52.172462Z",
     "shell.execute_reply": "2021-05-25T01:05:52.172870Z"
    }
   },
   "outputs": [],
   "source": [
    "n_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "runSGD(model, input_train, input_test, n_epochs=n_epochs,\n",
    "       batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:52.179390Z",
     "iopub.status.busy": "2021-05-25T01:05:52.178829Z",
     "iopub.status.idle": "2021-05-25T01:05:52.269062Z",
     "shell.execute_reply": "2021-05-25T01:05:52.269521Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  latent_test = encoder(input_test)\n",
    "\n",
    "plot_latent_3d(latent_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3.5: Visualize latent space on surface of $S_2$\n",
    "The 3D coordinates $(s_1, s_2, s_3)$ on the surface of the unit sphere $S_2$ can be mapped to [spherical coordinates](https://en.wikipedia.org/wiki/Spherical_coordinate_system) $(r, \\theta, \\phi)$, as follows:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "r &= \\sqrt{s_1^2 + s_2^2 + s_3^2} \\\\\n",
    "\\phi &= \\arctan \\frac{s_2}{s_1} \\\\\n",
    "\\theta &= \\arccos\\frac{s_3}{r}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "![Spherical coordinates](https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders/spherical_coords.png)\n",
    "\n",
    "What is the domain (numerical range) spanned by $(\\theta, \\phi)$?\n",
    "\n",
    "We return to a 2D representation since the angles $(\\theta, \\phi)$ are the only degrees of freedom on the surface of the sphere. Add the keyword `s2=True` to `plot_latent_generative` to unwrap the sphere's surface, similar to a world map.\n",
    "\n",
    "Task: Check the numerical range of the plot axes to help identify $\\theta$ and $\\phi$, and visualize the unfolding of the 3D plot from the previous exercise.\n",
    "\n",
    "**Instructions:**\n",
    "* Please execute the cells below"
   ]
  },
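  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mapping to spherical coordinates can be tried out on individual points with a few lines of plain Python. This is an illustrative sketch, not the code behind `plot_latent_generative`: the function name `to_spherical` is hypothetical, and `math.atan2` is used in place of a bare $\\arctan$ so that the quadrant of $(s_1, s_2)$ is taken into account.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def to_spherical(s1, s2, s3):\n",
    "  # radius; equals 1 for points already on the unit sphere\n",
    "  r = math.sqrt(s1 ** 2 + s2 ** 2 + s3 ** 2)\n",
    "  # azimuthal angle, measured in the (s1, s2) plane\n",
    "  phi = math.atan2(s2, s1)\n",
    "  # polar angle, measured from the positive s3 axis\n",
    "  theta = math.acos(s3 / r)\n",
    "  return r, theta, phi\n",
    "\n",
    "\n",
    "# a pole and two points on the equator\n",
    "print(to_spherical(0.0, 0.0, 1.0))\n",
    "print(to_spherical(0.0, 1.0, 0.0))\n",
    "print(to_spherical(-1.0, 0.0, 0.0))\n",
    "```\n",
    "\n",
    "Trying a few more points, including ones with negative coordinates, is a quick way to explore which values each angle can take."
   ]
  },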
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:52.280676Z",
     "iopub.status.busy": "2021-05-25T01:05:52.280139Z",
     "iopub.status.idle": "2021-05-25T01:05:55.464630Z",
     "shell.execute_reply": "2021-05-25T01:05:55.465106Z"
    }
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "  output_test = model(input_test)\n",
    "\n",
    "plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "         image_shape=image_shape)\n",
    "\n",
    "plot_latent_generative(latent_test, y_test, decoder,\n",
    "                       image_shape=image_shape, s2=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Summary\n",
    "We learned two techniques to improve representation capacity: adding a few hidden layers and projecting the latent space onto the sphere $S_2$.\n",
    "\n",
    "The expressive power of the autoencoder improves with additional hidden layers. Projecting the latent space onto the surface of $S_2$ spreads out digit classes in a more visually pleasing way but may not always produce a lower loss.\n",
    "\n",
    "**Deep autoencoder architectures have rich internal representations to deal with sophisticated tasks such as the MNIST cognitive task.**\n",
    "\n",
    "We now have powerful tools to explore how simple algorithms build robust models of the world by capturing relevant data patterns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:55.470610Z",
     "iopub.status.busy": "2021-05-25T01:05:55.470051Z",
     "iopub.status.idle": "2021-05-25T01:05:55.523519Z",
     "shell.execute_reply": "2021-05-25T01:05:55.523997Z"
    }
   },
   "outputs": [],
   "source": [
    "# @title Video 2: Wrap-up\n",
    "from IPython.display import YouTubeVideo\n",
    "video = YouTubeVideo(id=\"GnkmzCqEK3E\", width=854, height=480, fs=1)\n",
    "print(\"Video available at https://youtube.com/watch?v=\" + video.id)\n",
    "video"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "# Bonus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deep and thick autoencoder\n",
    "In this exercise, we expand the first hidden layer to double the input size and then compress to half the input size, which leads to 3.8M parameters. Please **do not train this network during the tutorial** because training takes a long time.\n",
    "\n",
    "**Instructions:**\n",
    "* Please uncomment and execute the cells below"
   ]
  },
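  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 3.8M figure can be checked by hand from the layer sizes. The sketch below is plain Python and does not use `print_parameter_count`; the function name `count_linear_parameters` is hypothetical. It assumes `input_size = 784` (flattened 28 x 28 MNIST images), one learnable slope per `nn.PReLU()` (the PyTorch default), and no learnable parameters in `NormalizeLayer`.\n",
    "\n",
    "```python\n",
    "def count_linear_parameters(layer_sizes):\n",
    "  # each fully connected layer has n_in * n_out weights plus n_out biases\n",
    "  return sum(n_in * n_out + n_out\n",
    "             for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))\n",
    "\n",
    "\n",
    "input_size = 784  # assumption: flattened 28 x 28 MNIST images\n",
    "encoding_size = 3\n",
    "\n",
    "# unit counts from the input through the bottleneck to the output\n",
    "layer_sizes = [input_size, input_size * 2, input_size // 2,\n",
    "               encoding_size * 32, encoding_size,\n",
    "               encoding_size * 32, input_size // 2, input_size * 2,\n",
    "               input_size]\n",
    "\n",
    "n_linear = count_linear_parameters(layer_sizes)\n",
    "# assumption: seven PReLU activations with one learnable slope each\n",
    "n_prelu = 7\n",
    "print(n_linear + n_prelu)\n",
    "```\n",
    "\n",
    "Under these assumptions the total is about 3.77 million, in line with the 3.8M quoted above. Almost all of it comes from the two layers that connect the input-sized and double-width layers."
   ]
  },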
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:55.529042Z",
     "iopub.status.busy": "2021-05-25T01:05:55.527767Z",
     "iopub.status.idle": "2021-05-25T01:05:55.529633Z",
     "shell.execute_reply": "2021-05-25T01:05:55.530104Z"
    }
   },
   "outputs": [],
   "source": [
    "# encoding_size = 3\n",
    "\n",
    "# model = nn.Sequential(\n",
    "#     nn.Linear(input_size, int(input_size * 2)),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(int(input_size * 2), int(input_size / 2)),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(int(input_size / 2), encoding_size * 32),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(encoding_size * 32, encoding_size),\n",
    "#     nn.PReLU(),\n",
    "#     NormalizeLayer(),\n",
    "#     nn.Linear(encoding_size, encoding_size * 32),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(encoding_size * 32, int(input_size / 2)),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(int(input_size / 2), int(input_size * 2)),\n",
    "#     nn.PReLU(),\n",
    "#     nn.Linear(int(input_size * 2), input_size),\n",
    "#     nn.Sigmoid()\n",
    "#     )\n",
    "\n",
    "# model[:-2].apply(init_weights_kaiming_normal)\n",
    "\n",
    "# encoder = model[:9]\n",
    "# decoder = model[9:]\n",
    "\n",
    "# print_parameter_count(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-05-25T01:05:55.534848Z",
     "iopub.status.busy": "2021-05-25T01:05:55.533776Z",
     "iopub.status.idle": "2021-05-25T01:05:55.536294Z",
     "shell.execute_reply": "2021-05-25T01:05:55.535827Z"
    }
   },
   "outputs": [],
   "source": [
    "# n_epochs = 5\n",
    "# batch_size = 128\n",
    "\n",
    "# runSGD(model, input_train, input_test, n_epochs=n_epochs,\n",
    "#        batch_size=batch_size)\n",
    "\n",
    "# Visualization\n",
    "# with torch.no_grad():\n",
    "#   output_test = model(input_test)\n",
    "#   latent_test = encoder(input_test)  # recompute for this model\n",
    "\n",
    "# plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],\n",
    "#          image_shape=image_shape)\n",
    "\n",
    "# plot_latent_generative(latent_test, y_test, decoder,\n",
    "#                        image_shape=image_shape, s2=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "JgF7_zvb8d0C"
   ],
   "include_colab_link": true,
   "name": "Tutorial2",
   "provenance": [],
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
